import torch
import torch.nn as nn
import torch.nn.functional as F
def flatten_feat(x):
    """Collapse every dimension after the first into a single axis.

    Args:
        x: Tensor whose leading dimension is kept unchanged.

    Returns:
        A tensor of shape ``(x.size(0), -1)``.
    """
    leading = x.size(0)
    return x.reshape(leading, -1)
def compute_dynamic_lambda(current_epoch, total_epochs, lambda_start=0.01, lambda_end=0.1):
    """Move a weight from ``lambda_start`` to ``lambda_end`` along a half cosine.

    Args:
        current_epoch: Epoch index; 0 yields ``lambda_start``.
        total_epochs: Number of epochs; ``total_epochs - 1`` (at least 1) is
            the epoch at which ``lambda_end`` is reached.
        lambda_start: Value returned at epoch 0.
        lambda_end: Value returned at the last epoch.

    Returns:
        The interpolated weight as a Python float.
    """
    from math import cos, pi

    last_epoch = max(total_epochs - 1, 1)
    progress = current_epoch / last_epoch
    span = lambda_end - lambda_start
    cosine_term = 1 - cos(progress * pi)
    return lambda_start + span * cosine_term / 2
def relaxed_contrastive_loss(A, B, lambda_weight=0.6, current_epoch=0, total_epochs=100,
                           use_lambda_schedule=False, lambda_start=0.01, lambda_end=0.1):
    """Contrastive loss on the cosine-similarity matrix between ``A`` and ``B``.

    Both inputs are flattened per sample and L2-normalised. Matching rows
    (the diagonal) are penalised by ``(1 - sim) ** 2``; non-matching pairs are
    penalised by the square of the positive part of their similarity, scaled
    by a lambda weight. The sum is divided by ``n * n`` with ``n`` the number
    of rows of the similarity matrix.

    Args:
        A, B: Tensors with the sample axis first.
        lambda_weight: Weight of the negative term when no schedule is used.
        current_epoch, total_epochs, lambda_start, lambda_end: Forwarded to
            ``compute_dynamic_lambda`` when ``use_lambda_schedule`` is true.
        use_lambda_schedule: If true, the scheduled value replaces
            ``lambda_weight`` entirely.

    Returns:
        Scalar loss tensor.
    """
    weight = (
        compute_dynamic_lambda(current_epoch, total_epochs, lambda_start, lambda_end)
        if use_lambda_schedule
        else lambda_weight
    )

    a_unit = F.normalize(flatten_feat(A), p=2, dim=1)
    b_unit = F.normalize(flatten_feat(B), p=2, dim=1)
    similarity = torch.mm(a_unit, b_unit.T)
    batch = similarity.size(0)

    # Pull term: matched pairs should have similarity 1.
    matched = torch.diag(similarity)
    pull = ((1.0 - matched) ** 2).sum()

    # Push term: only positively correlated mismatched pairs are penalised.
    is_negative_pair = ~torch.eye(batch, dtype=torch.bool, device=similarity.device)
    push = (torch.clamp(similarity[is_negative_pair], min=0.0) ** 2).sum()

    return (pull + weight * push) / (batch * batch)
def info_nce_loss(A, B, temperature=0.1):
    """Cross-entropy over temperature-scaled cosine similarities.

    Row ``i`` of ``A`` is treated as matching row ``i`` of ``B``; every other
    row of ``B`` acts as a negative.

    Args:
        A, B: Tensors with the sample axis first; flattened per sample.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar loss tensor from ``F.cross_entropy``.
    """
    a_unit, b_unit = (F.normalize(flatten_feat(t), dim=1) for t in (A, B))
    scaled_similarity = torch.mm(a_unit, b_unit.T) / temperature
    targets = torch.arange(A.size(0), device=A.device)
    return F.cross_entropy(scaled_similarity, targets)
def variance_loss(z, gamma=1.0):
    """Hinge penalty on columns of ``z`` whose standard deviation is below ``gamma``.

    Args:
        z: Tensor; variance is taken along dimension 0.
        gamma: Target lower bound for the per-column standard deviation.

    Returns:
        Mean of ``relu(gamma - std)`` over the remaining dimensions.
    """
    per_column_var = z.var(dim=0)
    # The small constant keeps the square root well-behaved at zero variance.
    per_column_std = (per_column_var + 1e-04).sqrt()
    shortfall = F.relu(gamma - per_column_std)
    return shortfall.mean()
def covariance_loss(z):
    """Penalise similarity between different rows of a centred 2-D tensor.

    The column means are subtracted, each row is L2-normalised, and the
    ``n x n`` matrix of row-to-row dot products is formed. The squared
    off-diagonal entries are summed and divided by ``n``.

    Args:
        z: 2-D tensor of shape ``(n, d)``.

    Returns:
        Scalar loss tensor.
    """
    n, _ = z.shape
    centred = z - z.mean(dim=0, keepdim=True)
    unit_rows = F.normalize(centred, p=2, dim=1)
    gram = torch.mm(unit_rows, unit_rows.T)
    keep_off_diagonal = 1 - torch.eye(n, device=z.device)
    return ((gram * keep_off_diagonal) ** 2).sum() / n
def create_frequency_filters(shape, cutoff_ratio=0.3, device=None):
    """Build complementary low-pass / high-pass masks for a 2-D FFT grid.

    The radial frequency of every FFT bin is divided by 0.5 and compared with
    ``cutoff_ratio``: bins at or below the cutoff go into the low-pass mask,
    all others into the high-pass mask.

    Args:
        shape: 4-element shape; only the last two entries (``H``, ``W``) are used.
        cutoff_ratio: Threshold on the normalised radial frequency.
        device: Device on which to build the masks. When ``None`` the previous
            behaviour is kept: CUDA if available, otherwise CPU. Passing the
            device of the data avoids a later cross-device copy.

    Returns:
        ``(low_pass_mask, high_pass_mask)``, two float tensors of shape
        ``(1, 1, H, W)`` that sum to one everywhere.
    """
    _, _, H, W = shape
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    freq_h = torch.fft.fftfreq(H, device=device)
    freq_w = torch.fft.fftfreq(W, device=device)
    freq_grid_h, freq_grid_w = torch.meshgrid(freq_h, freq_w, indexing='ij')
    freq_distance = torch.sqrt(freq_grid_h ** 2 + freq_grid_w ** 2)

    max_freq = 0.5  # largest per-axis magnitude returned by fftfreq
    normalized_distance = freq_distance / max_freq

    low_pass_mask = (normalized_distance <= cutoff_ratio).float()
    high_pass_mask = (normalized_distance > cutoff_ratio).float()
    # Add leading singleton axes so the masks broadcast over batch and channel.
    return low_pass_mask[None, None], high_pass_mask[None, None]
def rgb_to_lab_torch(rgb_tensor):
    """Convert an ``(N, 3, H, W)`` RGB tensor to Lab rescaled into ``[0, 1]``.

    The input is clamped to ``[0, 1]``, linearised with the sRGB transfer
    curve, mapped to XYZ with the matrix below, divided by the reference white
    and pushed through the Lab non-linearity. ``L`` is divided by 100; ``a``
    and ``b`` are clamped to ``[-128, 127]`` and mapped with ``(v + 128) / 255``.

    Args:
        rgb_tensor: 4-D tensor with three channels on dimension 1.

    Returns:
        Tensor of the same spatial shape with channels ``[L, a, b]`` in ``[0, 1]``.
    """
    tiny = 1e-8
    N, _, H, W = rgb_tensor.shape
    srgb = rgb_tensor.clamp(0, 1)

    # Undo the sRGB companding: linear segment near zero, power curve elsewhere.
    on_linear_segment = srgb <= 0.04045
    linear_rgb = torch.where(
        on_linear_segment,
        srgb / 12.92,
        torch.pow((srgb + 0.055) / 1.055, 2.4)
    )

    rgb_to_xyz_matrix = torch.tensor([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041]
    ], device=rgb_tensor.device, dtype=rgb_tensor.dtype)

    # Apply the 3x3 matrix per pixel by moving channels last and flattening.
    pixels = linear_rgb.permute(0, 2, 3, 1).reshape(-1, 3)
    xyz = torch.mm(pixels, rgb_to_xyz_matrix.t()).reshape(N, H, W, 3).permute(0, 3, 1, 2)

    white_point = torch.tensor([0.95047, 1.00000, 1.08883], device=xyz.device).view(1, 3, 1, 1)
    xyz_relative = torch.clamp(xyz / (white_point + tiny), 0, 10)

    cie_epsilon = 0.008856
    cie_kappa = 903.3
    f_xyz = torch.where(
        xyz_relative > cie_epsilon,
        torch.pow(torch.clamp(xyz_relative, tiny, 1), 1/3),
        (cie_kappa * xyz_relative + 16) / 116
    )
    fx, fy, fz = f_xyz[:, 0:1], f_xyz[:, 1:2], f_xyz[:, 2:3]

    lightness = torch.clamp(116 * fy - 16, 0, 100) / 100.0
    a_channel = (torch.clamp(500 * (fx - fy), -128, 127) + 128.0) / 255.0
    b_channel = (torch.clamp(200 * (fy - fz), -128, 127) + 128.0) / 255.0

    lab = torch.cat([lightness, a_channel, b_channel], dim=1)
    # Replace any non-finite values before the final range clamp.
    lab = torch.nan_to_num(lab, nan=0.5, posinf=1.0, neginf=0.0)
    return torch.clamp(lab, 0, 1)
def frequency_aware_relaxed_contrastive_loss(A, B, cutoff_ratio=0.3,
                                           low_freq_lambda=0.02, high_freq_lambda=0.15,
                                           current_epoch=0, total_epochs=100,
                                           use_lambda_schedule=False, lambda_start=0.01, lambda_end=0.1,
                                           use_color_space_separation=False, luminance_weight=0.7, chrominance_weight=0.3):
    """Relaxed contrastive loss applied separately to low and high spatial frequencies.

    Each input is split with a 2-D FFT into a low band and a high band (masks
    from ``create_frequency_filters``), both bands are transformed back, and
    ``relaxed_contrastive_loss`` is evaluated per band and summed.

    When ``use_color_space_separation`` is true and the input has three
    channels, the inputs are first converted with ``rgb_to_lab_torch``; the
    first channel and the remaining two are scored separately (the latter with
    both lambdas multiplied by 0.6) and mixed with ``luminance_weight`` and
    ``chrominance_weight``.

    Note: when ``use_lambda_schedule`` is true, ``relaxed_contrastive_loss``
    replaces the per-band lambda with the scheduled value, so both bands use
    the same weight.

    Args:
        A, B: 4-D tensors ``(N, C, H, W)``.
        cutoff_ratio: Normalised frequency threshold separating the bands.
        low_freq_lambda, high_freq_lambda: Negative-term weights per band.
        current_epoch, total_epochs, use_lambda_schedule, lambda_start,
            lambda_end: Forwarded to ``relaxed_contrastive_loss``.
        use_color_space_separation, luminance_weight, chrominance_weight:
            Control the optional Lab split described above.

    Returns:
        Scalar loss tensor.
    """
    N, C, H, W = A.shape
    device = A.device
    schedule_kwargs = dict(
        current_epoch=current_epoch,
        total_epochs=total_epochs,
        use_lambda_schedule=use_lambda_schedule,
        lambda_start=lambda_start,
        lambda_end=lambda_end,
    )

    def to_spatial(spectrum):
        # Inverse FFT; only the real part is kept.
        return torch.fft.ifft2(spectrum, dim=(-2, -1)).real

    def band_split_loss(x, y, low_lambda, high_lambda):
        # Low-band loss plus high-band loss for one pair of tensors.
        x_spec = torch.fft.fft2(x, dim=(-2, -1))
        y_spec = torch.fft.fft2(y, dim=(-2, -1))
        low_mask, high_mask = create_frequency_filters(x.shape, cutoff_ratio)
        low_mask = low_mask.to(device)
        high_mask = high_mask.to(device)
        low_loss = relaxed_contrastive_loss(
            to_spatial(x_spec * low_mask), to_spatial(y_spec * low_mask),
            lambda_weight=low_lambda, **schedule_kwargs)
        high_loss = relaxed_contrastive_loss(
            to_spatial(x_spec * high_mask), to_spatial(y_spec * high_mask),
            lambda_weight=high_lambda, **schedule_kwargs)
        return low_loss + high_loss

    if not (use_color_space_separation and C == 3):
        return band_split_loss(A, B, low_freq_lambda, high_freq_lambda)

    A_lab = rgb_to_lab_torch(A)
    B_lab = rgb_to_lab_torch(B)
    L_loss = band_split_loss(A_lab[:, [0]], B_lab[:, [0]],
                             low_freq_lambda, high_freq_lambda)
    ab_loss = band_split_loss(A_lab[:, 1:], B_lab[:, 1:],
                              low_freq_lambda * 0.6, high_freq_lambda * 0.6)
    return luminance_weight * L_loss + chrominance_weight * ab_loss
def three_term_loss(A, B, inv_w=2.0, var_w=2.0, cov_w=1.0, gamma=1.0,
                   current_epoch=0, total_epochs=100, infonce_temperature=0.1,
                   mse_start_ratio=1.0, mse_end_ratio=0.0,
                   use_relaxed_contrastive=False, reco_lambda=0.6,
                   use_lambda_schedule=False, lambda_start=0.01, lambda_end=0.1):
    """Invariance + variance + covariance loss between two feature tensors.

    The invariance term blends MSE with a contrastive loss. The MSE share
    moves linearly from ``mse_start_ratio`` to ``mse_end_ratio`` over
    training and the contrastive share is its complement. If the flattened
    shapes differ, MSE is skipped and the contrastive term gets the full
    ``inv_w``.

    Args:
        A, B: Tensors with the sample axis first.
        inv_w, var_w, cov_w: Weights of the three terms.
        gamma: Target standard deviation for ``variance_loss``.
        current_epoch, total_epochs: Position in training.
        infonce_temperature: Temperature for ``info_nce_loss``.
        mse_start_ratio, mse_end_ratio: End points of the MSE share.
        use_relaxed_contrastive: Use ``relaxed_contrastive_loss`` instead of
            ``info_nce_loss``.
        reco_lambda, use_lambda_schedule, lambda_start, lambda_end: Forwarded
            to ``relaxed_contrastive_loss``.

    Returns:
        Scalar loss tensor.
    """
    progress = current_epoch / max(total_epochs - 1, 1)
    mse_ratio = mse_start_ratio + (mse_end_ratio - mse_start_ratio) * progress
    infonce_ratio = (1.0 - mse_start_ratio) + (mse_start_ratio - mse_end_ratio) * progress

    ZA, ZB = flatten_feat(A), flatten_feat(B)
    if ZA.shape != ZB.shape:
        # MSE is undefined for mismatched shapes; rely on the contrastive term only.
        mse_loss = torch.tensor(0.0, device=ZA.device, requires_grad=True)
        w_mse, w_infonce = 0.0, inv_w
    else:
        mse_loss = F.mse_loss(ZA, ZB)
        w_mse, w_infonce = inv_w * mse_ratio, inv_w * infonce_ratio

    if not use_relaxed_contrastive:
        contrastive_loss = info_nce_loss(A, B, temperature=infonce_temperature)
    else:
        contrastive_loss = relaxed_contrastive_loss(
            A, B,
            lambda_weight=reco_lambda,
            current_epoch=current_epoch,
            total_epochs=total_epochs,
            use_lambda_schedule=use_lambda_schedule,
            lambda_start=lambda_start,
            lambda_end=lambda_end,
        )

    inv_loss = w_mse * mse_loss + w_infonce * contrastive_loss
    var_loss = variance_loss(ZA, gamma) + variance_loss(ZB, gamma)
    cov_loss = covariance_loss(ZA) + covariance_loss(ZB)
    return inv_loss + var_w * var_loss + cov_w * cov_loss
class C2Loss_ConvAutoEncoder_Test(nn.Module):
    """Layer-wise loss for a convolutional auto-encoder.

    With ``method="local"`` one loss is computed per (activation, signal)
    pair and the losses are summed. With any other ``method`` only the last
    activation is scored against ``target`` using binary cross-entropy.
    """
    def __init__(self):
        super(C2Loss_ConvAutoEncoder_Test, self).__init__()

    def forward(self, activations, signals, target, method="local", args=None):
        """Compute the total loss and the list of per-layer losses.

        Args:
            activations: Sequence of layer outputs; 5-D entries are averaged
                over dimension 0 before use.
            signals: Sequence of per-layer targets, indexed like ``activations``.
            target: Target for the last layer.
            method: ``"local"`` for layer-wise losses, anything else for a
                single binary cross-entropy on the last activation.
            args: Optional object whose attributes override the defaults read
                below; a falsy value means "use all defaults".

        Returns:
            ``(total_loss, layer_losses)``.
        """
        if method != "local":
            reconstruction = activations[-1]
            if len(reconstruction.shape) == 5:
                reconstruction = reconstruction.mean(dim=0)
            if len(target.shape) == 5:
                target = target.mean(dim=0)
            total_loss = F.binary_cross_entropy(reconstruction, target)
            return total_loss, [total_loss]

        def opt(name, default):
            # Single place for the "attribute of args, else default" lookup.
            return getattr(args, name, default) if args else default

        # Development toggle: the legacy C/D branch below is kept but disabled.
        use_three_term = True
        use_relaxed_contrastive = opt('use_relaxed_contrastive', True)
        reco_lambda = opt('reco_lambda', 0.6)
        use_lambda_schedule = opt('use_lambda_schedule', False)
        lambda_start = opt('lambda_start', 0.01)
        lambda_end = opt('lambda_end', 0.1)
        use_frequency_loss = opt('use_frequency_aware_loss', False)
        freq_cutoff_ratio = opt('freq_cutoff_ratio', 0.3)
        freq_low_lambda = opt('freq_low_lambda', 0.02)
        freq_high_lambda = opt('freq_high_lambda', 0.15)
        loss_scale_C = opt('loss_scale_C', 1.0)
        loss_scale_D = opt('loss_scale_D', 1.0)
        inv_w = opt('inv_w', 2.0)
        var_w = opt('var_w', 2.0)
        cov_w = opt('cov_w', 1.0)
        gamma = opt('gamma', 1.0)
        mse_start_ratio = opt('mse_start_ratio', 1.0)
        mse_end_ratio = opt('mse_end_ratio', 0.0)
        current_epoch = opt('current_epoch', 0)
        total_epochs = opt('total_epochs', 100)
        use_final_reco = opt('use_final_layer_reco', False)
        final_lambda = opt('final_layer_lambda', 0.05)
        use_color_space = opt('use_color_space_separation', False)
        luminance_weight = opt('luminance_weight', 0.7)
        chrominance_weight = opt('chrominance_weight', 0.3)
        # Previously referenced without ever being defined (NameError); it is
        # only used for the final-layer InfoNCE fallback below.
        infonce_temperature = opt('infonce_temperature', 0.1)

        schedule_kwargs = dict(
            current_epoch=current_epoch,
            total_epochs=total_epochs,
            use_lambda_schedule=use_lambda_schedule,
            lambda_start=lambda_start,
            lambda_end=lambda_end,
        )

        def final_reco_loss(A, B, scale, allow_relaxed):
            # Contrastive loss for the last layer when use_final_layer_reco is set.
            if use_frequency_loss and len(A.shape) == 4 and len(B.shape) == 4:
                return frequency_aware_relaxed_contrastive_loss(
                    A, B,
                    cutoff_ratio=freq_cutoff_ratio,
                    low_freq_lambda=final_lambda * 0.5,
                    high_freq_lambda=final_lambda * 1.0,
                    use_color_space_separation=False,
                    **schedule_kwargs
                ) * scale
            if allow_relaxed and use_relaxed_contrastive:
                return relaxed_contrastive_loss(
                    A, B, lambda_weight=final_lambda, **schedule_kwargs
                ) * scale
            # Fixed: the original called the undefined name ``infonce_loss``.
            return info_nce_loss(A, B, temperature=infonce_temperature) * scale

        last_index = len(activations) - 1
        layer_losses = []
        for i, A in enumerate(activations):
            B = signals[i]
            if len(A.shape) == 5:
                A = A.mean(dim=0)
            if len(B.shape) == 5:
                B = B.mean(dim=0)

            if use_three_term:
                if i == last_index:
                    if use_final_reco:
                        loss = final_reco_loss(A, B, inv_w, allow_relaxed=True)
                    else:
                        loss = F.mse_loss(A, target) * A.shape[2]
                elif use_frequency_loss and len(A.shape) == 4 and len(B.shape) == 4:
                    freq_loss = frequency_aware_relaxed_contrastive_loss(
                        A, B,
                        cutoff_ratio=freq_cutoff_ratio,
                        low_freq_lambda=freq_low_lambda,
                        high_freq_lambda=freq_high_lambda,
                        use_color_space_separation=use_color_space,
                        luminance_weight=luminance_weight,
                        chrominance_weight=chrominance_weight,
                        **schedule_kwargs
                    )
                    A_flat = flatten_feat(A)
                    B_flat = flatten_feat(B)
                    var_reg = variance_loss(A_flat, gamma) + variance_loss(B_flat, gamma)
                    cov_reg = covariance_loss(A_flat) + covariance_loss(B_flat)
                    loss = inv_w * freq_loss + var_w * var_reg + cov_w * cov_reg
                else:
                    loss = three_term_loss(
                        A, B,
                        inv_w=inv_w,
                        var_w=var_w,
                        cov_w=cov_w,
                        gamma=gamma,
                        mse_start_ratio=mse_start_ratio,
                        mse_end_ratio=mse_end_ratio,
                        use_relaxed_contrastive=use_relaxed_contrastive,
                        reco_lambda=reco_lambda,
                        **schedule_kwargs
                    )
            else:
                A_flat = A.view(A.size(0), -1)
                B_flat = B.view(B.size(0), -1)
                loss_C = F.mse_loss(A_flat, B_flat) * A.shape[2] * loss_scale_C
                A_norm = (A_flat - A_flat.mean(dim=1, keepdim=True)) / (A_flat.norm(dim=1, keepdim=True) + 1e-8)
                B_norm = (B_flat - B_flat.mean(dim=1, keepdim=True)) / (B_flat.norm(dim=1, keepdim=True) + 1e-8)
                loss_D = F.mse_loss(A_norm, B_norm) * A.shape[2] * loss_scale_D
                if i == 0:
                    loss = loss_C
                elif i == last_index:
                    if use_final_reco:
                        loss = final_reco_loss(A, B, loss_scale_C, allow_relaxed=False)
                    else:
                        loss = F.mse_loss(A, target) * A.shape[2]
                else:
                    loss = loss_C + loss_D
            layer_losses.append(loss)

        total_loss = sum(layer_losses)
        return total_loss, layer_losses
def gradient_centralization(model):
    """Subtract the per-output-unit mean from weight gradients, in place.

    Parameters whose name contains ``"bias"`` or that have no gradient are
    skipped. 2-D gradients are centred over dimension 1, 4-D gradients over
    dimensions 1-3; other ranks are left untouched.

    Args:
        model: Object exposing ``named_parameters()``.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None or "bias" in name:
                continue
            rank = len(param.shape)
            if rank == 2:
                grad.sub_(grad.mean(dim=1, keepdim=True))
            elif rank == 4:
                grad.sub_(grad.mean(dim=[1, 2, 3], keepdim=True))
def normalize_along_axis(x, norm=True):
    """Flatten each sample and optionally scale it to unit L2 norm.

    Args:
        x: Tensor (or sequence supporting ``len`` and ``reshape``) with the
            sample axis first.
        norm: When true, each flattened row is divided by its L2 norm plus
            ``1e-8``. When false, the flattened tensor is returned unchanged.
            (The previous code divided by ``False + 1e-8`` in that case,
            multiplying the values by 1e8.)

    Returns:
        2-D tensor of shape ``(len(x), -1)``.
    """
    x = x.reshape(len(x), -1)
    if not norm:
        return x
    row_norm = torch.norm(x, dim=1, keepdim=True)
    # The small constant guards against division by zero for all-zero rows.
    return x / (row_norm + 1e-8)
def compute_scl_loss(args, A, B):
    """Scaled mean-squared error between an activation and its signal.

    Args:
        args: Object providing ``loss_scale_C``.
        A, B: Tensors accepted by ``F.mse_loss``; ``A`` needs at least three
            dimensions because the loss is multiplied by ``A.shape[2]``.

    Returns:
        Scalar loss tensor.
    """
    mismatch = F.mse_loss(A, B)
    return mismatch * A.shape[2] * args.loss_scale_C
class C2Loss_AE(nn.Module):
    """Loss wrapper offering a layer-wise ("local") or final-output criterion."""
    def __init__(self, args):
        super(C2Loss_AE, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_scl_loss
        self.method = args.method

    def forward(self, activations, signals, target, method="final"):
        """Compute the loss for the requested ``method``.

        Args:
            activations: Sequence of layer outputs.
            signals: Sequence of per-layer targets paired with ``activations``.
            target: Target passed to the final criterion.
            method: ``"local"`` sums ``local_criteria`` over all pairs;
                ``"final"`` applies cross-entropy to the last activation.
                Any other value returns ``None``.

        Returns:
            ``(loss_tensor, float)`` where the float is the last layer's loss
            for ``"local"`` or the loss itself for ``"final"``.
        """
        if method == "final":
            final_loss = self.final_criteria(activations[-1], target)
            return final_loss, final_loss.item()
        if method != "local":
            return None

        per_layer = []
        for act, sig in zip(activations, signals):
            act_rank, sig_rank = len(act.shape), len(sig.shape)
            # Give a flat tensor the spatial layout of its 4-D partner.
            if act_rank == 4 and sig_rank == 2:
                sig = sig.view(sig.shape[0], sig.shape[1], act.shape[2], act.shape[3])
            elif act_rank == 2 and sig_rank == 4:
                act = act.view(act.shape[0], act.shape[1], sig.shape[2], sig.shape[3])
            per_layer.append(self.local_criteria(self.args, act, sig))
        return sum(per_layer), per_layer[-1].item()
def compute_scl_loss_classification(args, A, B, target, predictions):
    """Cross-correlation loss between ``A`` and ``B`` with optional sample filtering.

    When ``args.filter_target`` is non-zero, samples that are already
    predicted well are zeroed out in ``A``:

    * with ``args.use_label_encoding`` set and ``args.label_encodings``
      available, a sample is dropped when the cosine similarity between its
      prediction and its class encoding exceeds 0.9;
    * with label encoding set but no encodings, nothing is dropped;
    * otherwise a sample is dropped when its softmax probability for the
      target class is within 0.1 of 1.

    The loss is the MSE between ``A @ B.T`` (on flattened inputs) and the
    identity, divided by the feature size and multiplied by
    ``args.loss_scale_C * 256``.

    Args:
        args: Object providing ``filter_target`` and ``loss_scale_C`` and,
            optionally, ``use_label_encoding`` / ``label_encodings``.
        A, B: Tensors with the sample axis first (``A`` with 2 to 5 dimensions
            when filtering is active).
        target: Class indices, one per sample.
        predictions: Model outputs used only for the filtering mask.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If filtering is active and ``A`` has an unsupported rank.
    """
    if args.filter_target != 0:
        with torch.no_grad():
            if getattr(args, 'use_label_encoding', False):
                mask = torch.ones(predictions.shape[0], device=predictions.device)
                encodings = getattr(args, 'label_encodings', None)
                if encodings is not None:
                    for i, pred_encoding in enumerate(predictions):
                        target_encoding = encodings[target[i].item()]
                        similarity = F.cosine_similarity(
                            pred_encoding.flatten().unsqueeze(0),
                            target_encoding.flatten().unsqueeze(0),
                            dim=1,
                        )
                        mask[i] = 0.0 if similarity.item() > 0.9 else 1.0
            else:
                probabilities = F.softmax(predictions, dim=1)
                one_hot = torch.zeros_like(probabilities, dtype=torch.float32)
                one_hot.scatter_(1, target.unsqueeze(1), 1.0)
                is_close = torch.abs(probabilities - one_hot) < 0.1
                # 1 for samples to keep, 0 for confidently correct samples.
                mask = 1 - torch.sum(is_close * one_hot, dim=1).to(torch.float32)

        rank = len(A.shape)
        if rank < 2 or rank > 5:
            raise ValueError(f"Unsupported shape for A: {A.shape}")
        # Reshape the mask to (batch, 1, ..., 1) so it broadcasts over A.
        A = A * mask.view(mask.shape[0], *([1] * (rank - 1)))

    A = A.view(A.shape[0], -1)
    B = B.view(B.shape[0], -1)
    cross = A @ B.T
    identity = torch.eye(A.shape[0], dtype=torch.float32, device=target.device)
    return F.mse_loss(cross, identity) / A.shape[1] * args.loss_scale_C * 256
class C2Loss_Classification(nn.Module):
    """Classification loss with per-layer local terms and a final criterion."""
    def __init__(self, args):
        super(C2Loss_Classification, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_scl_loss_classification
        self.method = args.method

    def compute_label_encoding_final_loss(self, final_output, target):
        """MSE between the final output and each sample's label encoding.

        Args:
            final_output: 3-D tensor ``(batch, T, L)``.
            target: Class indices, one per sample.

        Returns:
            If ``self.args.label_encodings`` is missing or ``None``: the MSE
            against the one-hot target repeated ``T`` times along dimension 1.
            Otherwise: the mean over samples of the MSE between each sample's
            output and ``label_encodings[class_index]``.
        """
        # An unused import of utils.label_encoding used to run here on every
        # call and made the method fail when that module was not importable.
        batch_size, T, L = final_output.shape
        label_encodings = getattr(self.args, 'label_encodings', None)
        if label_encodings is None:
            target_one_hot = F.one_hot(target, num_classes=L).float()
            target_expanded = target_one_hot.unsqueeze(1).repeat(1, T, 1)
            return F.mse_loss(final_output, target_expanded)

        total_loss = 0.0
        for i, sample_output in enumerate(final_output):
            target_encoding = label_encodings[target[i].item()]
            total_loss += F.mse_loss(sample_output, target_encoding)
        return total_loss / batch_size

    def forward(self, activations, signals, target, method="final"):
        """Compute the loss for the requested ``method``.

        Args:
            activations: Sequence of layer outputs; the last entry is the
                prediction.
            signals: Sequence of per-layer targets paired with ``activations``.
            target: Class indices.
            method: ``"local"`` sums the local criterion over all but the last
                pair and adds a final-layer loss; ``"final"`` applies only
                cross-entropy to the last activation. Other values return
                ``None``.

        Returns:
            ``(loss_tensor, float)`` where the float is the final-layer loss.
        """
        if method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        if method != "local":
            return None

        losses = []
        for act, sig in zip(activations[:-1], signals[:-1]):
            act_rank, sig_rank = len(act.shape), len(sig.shape)
            # Give a flat tensor the spatial layout of its 4-D partner.
            if act_rank == 4 and sig_rank == 2:
                sig = sig.view(sig.shape[0], sig.shape[1], act.shape[2], act.shape[3])
            elif act_rank == 2 and sig_rank == 4:
                act = act.view(act.shape[0], act.shape[1], sig.shape[2], sig.shape[3])
            losses.append(self.local_criteria(self.args, act, sig, target, activations[-1]))

        if getattr(self.args, 'use_label_encoding', False):
            losses.append(self.compute_label_encoding_final_loss(activations[-1], target))
        else:
            losses.append(self.final_criteria(activations[-1], target))
        return sum(losses), losses[-1].item()
def compute_scl_loss_classification_cnn_test(args, A, B, target, predictions):
    """Cross- and self-correlation loss on row-normalised features.

    When ``args.filter_target`` is non-zero, samples that are already
    predicted well are zeroed out in ``A``:

    * with ``args.use_label_encoding`` set and ``args.label_encodings``
      available, a sample is dropped when the cosine similarity between its
      prediction and its class encoding exceeds 0.9;
    * with label encoding set but no encodings, nothing is dropped;
    * otherwise a sample is dropped when its softmax probability for the
      target class is within 0.1 of 1.

    ``A`` and ``B`` are then passed through ``normalize_along_axis`` and three
    MSE-to-identity terms are summed: ``A @ B.T`` (scaled by
    ``args.loss_scale_C``), ``A @ A.T`` and ``B @ B.T`` (each scaled by
    ``args.loss_scale_ssl``).

    Args:
        args: Object providing ``filter_target``, ``loss_scale_C`` and
            ``loss_scale_ssl`` and, optionally, ``use_label_encoding`` /
            ``label_encodings``.
        A, B: Tensors with the sample axis first (``A`` with 2 to 5 dimensions
            when filtering is active).
        target: Class indices, one per sample.
        predictions: Model outputs used only for the filtering mask.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If filtering is active and ``A`` has an unsupported rank.
    """
    if args.filter_target != 0:
        with torch.no_grad():
            if getattr(args, 'use_label_encoding', False):
                mask = torch.ones(predictions.shape[0], device=predictions.device)
                encodings = getattr(args, 'label_encodings', None)
                if encodings is not None:
                    for i, pred_encoding in enumerate(predictions):
                        target_encoding = encodings[target[i].item()]
                        similarity = F.cosine_similarity(
                            pred_encoding.flatten().unsqueeze(0),
                            target_encoding.flatten().unsqueeze(0),
                            dim=1,
                        )
                        mask[i] = 0.0 if similarity.item() > 0.9 else 1.0
            else:
                probabilities = F.softmax(predictions, dim=1)
                one_hot = torch.zeros_like(probabilities, dtype=torch.float32)
                one_hot.scatter_(1, target.unsqueeze(1), 1.0)
                is_close = torch.abs(probabilities - one_hot) < 0.1
                # 1 for samples to keep, 0 for confidently correct samples.
                mask = 1 - torch.sum(is_close * one_hot, dim=1).to(torch.float32)

        rank = len(A.shape)
        if rank < 2 or rank > 5:
            raise ValueError(f"Unsupported shape for A: {A.shape}")
        # Reshape the mask to (batch, 1, ..., 1) so it broadcasts over A.
        A = A * mask.view(mask.shape[0], *([1] * (rank - 1)))

    A = normalize_along_axis(A)
    B = normalize_along_axis(B)

    # Build the identity target once, on the device of the features it is
    # compared with (it used to be rebuilt per term, once on target.device).
    identity_a = torch.eye(A.shape[0], dtype=torch.float32, device=A.device)
    if B.shape[0] == A.shape[0] and B.device == A.device:
        identity_b = identity_a
    else:
        identity_b = torch.eye(B.shape[0], dtype=torch.float32, device=B.device)

    loss1 = F.mse_loss(A @ B.T, identity_a) * args.loss_scale_C
    loss2 = F.mse_loss(A @ A.T, identity_a) * args.loss_scale_ssl
    loss3 = F.mse_loss(B @ B.T, identity_b) * args.loss_scale_ssl
    return loss1 + loss2 + loss3
class C2Loss_Classification_CNN_Test(nn.Module):
    """Classification loss whose local term is either the three-term loss or
    the matrix-identity loss, selected by ``args.use_three_term_loss``."""
    def __init__(self, args):
        super(C2Loss_Classification_CNN_Test, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_scl_loss_classification_cnn_test
        self.method = args.method
        self.use_three_term = getattr(args, 'use_three_term_loss', False)
        if self.use_three_term:
            # Hyper-parameters are only read (and reported) when the
            # three-term loss is active.
            self.three_term_inv_w = getattr(args, 'three_term_inv_weight', 2.0)
            self.three_term_var_w = getattr(args, 'three_term_var_weight', 2.0)
            self.three_term_cov_w = getattr(args, 'three_term_cov_weight', 1.0)
            self.three_term_gamma = getattr(args, 'three_term_gamma', 1.0)
            self.three_term_temperature = getattr(args, 'three_term_infonce_temperature', 0.1)
            self.three_term_mse_start_ratio = getattr(args, 'three_term_mse_start_ratio', 1.0)
            self.three_term_mse_end_ratio = getattr(args, 'three_term_mse_end_ratio', 0.0)
            self.use_relaxed_contrastive = getattr(args, 'three_term_use_relaxed_contrastive', False)
            self.reco_lambda = getattr(args, 'three_term_reco_lambda', 0.6)
            contrastive_type = "ReCo (Relaxed Contrastive)" if self.use_relaxed_contrastive else "InfoNCE"
            print(f"🔥 Using Three-Term Loss: inv_w={self.three_term_inv_w}, var_w={self.three_term_var_w}, cov_w={self.three_term_cov_w}, gamma={self.three_term_gamma}")
            if self.use_relaxed_contrastive:
                print(f"   Contrastive: {contrastive_type} (λ={self.reco_lambda})")
            else:
                print(f"   Contrastive: {contrastive_type} (temp={self.three_term_temperature})")
            print(f"   MSE/Contrastive Scheduling: {self.three_term_mse_start_ratio:.1f}→{self.three_term_mse_end_ratio:.1f} (MSE ratio)")
        else:
            print("🔹 Using Original BSD Loss (matrix orthogonal constraint)")

    def compute_label_encoding_final_loss(self, final_output, target):
        """MSE between the final output and each sample's label encoding.

        Args:
            final_output: 3-D tensor ``(batch, T, L)``.
            target: Class indices, one per sample.

        Returns:
            If ``self.args.label_encodings`` is missing or ``None``: the MSE
            against the one-hot target repeated ``T`` times along dimension 1.
            Otherwise: the mean over samples of the MSE between each sample's
            output and ``label_encodings[class_index]``.
        """
        # An unused import of utils.label_encoding used to run here on every
        # call and made the method fail when that module was not importable.
        batch_size, T, L = final_output.shape
        label_encodings = getattr(self.args, 'label_encodings', None)
        if label_encodings is None:
            target_one_hot = F.one_hot(target, num_classes=L).float()
            target_expanded = target_one_hot.unsqueeze(1).repeat(1, T, 1)
            return F.mse_loss(final_output, target_expanded)

        total_loss = 0.0
        for i, sample_output in enumerate(final_output):
            target_encoding = label_encodings[target[i].item()]
            total_loss += F.mse_loss(sample_output, target_encoding)
        return total_loss / batch_size

    def forward(self, activations, signals, target, method="final", current_epoch=0, total_epochs=100):
        """Compute the loss for the requested ``method``.

        Args:
            activations: Sequence of layer outputs; the last entry is the
                prediction.
            signals: Sequence of per-layer targets paired with ``activations``.
            target: Class indices.
            method: ``"local"`` sums a local loss over all but the last pair
                and adds a final-layer loss; ``"final"`` applies only
                cross-entropy to the last activation. Other values return
                ``None``.
            current_epoch, total_epochs: Forwarded to ``three_term_loss``.

        Returns:
            ``(loss_tensor, float)`` where the float is the final-layer loss.
        """
        if method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        if method != "local":
            return None

        losses = []
        for act, sig in zip(activations[:-1], signals[:-1]):
            act_rank, sig_rank = len(act.shape), len(sig.shape)
            # Align the rank of the flatter tensor with its partner. Each view
            # keeps the element count, so it only succeeds when the borrowed
            # spatial dimensions multiply to one.
            if act_rank == 4 and sig_rank == 2:
                sig = sig.view(sig.shape[0], sig.shape[1], act.shape[2], act.shape[3])
            elif act_rank == 2 and sig_rank == 4:
                act = act.view(act.shape[0], act.shape[1], sig.shape[2], sig.shape[3])
            elif act_rank == 5 and sig_rank == 3:
                sig = sig.view(sig.shape[0], sig.shape[1], sig.shape[2], act.shape[3], act.shape[4])

            if self.use_three_term:
                layer_loss = three_term_loss(
                    act, sig,
                    inv_w=self.three_term_inv_w,
                    var_w=self.three_term_var_w,
                    cov_w=self.three_term_cov_w,
                    gamma=self.three_term_gamma,
                    current_epoch=current_epoch,
                    total_epochs=total_epochs,
                    infonce_temperature=self.three_term_temperature,
                    mse_start_ratio=self.three_term_mse_start_ratio,
                    mse_end_ratio=self.three_term_mse_end_ratio,
                    use_relaxed_contrastive=self.use_relaxed_contrastive,
                    reco_lambda=self.reco_lambda,
                )
            else:
                layer_loss = self.local_criteria(self.args, act, sig, target, activations[-1])
            losses.append(layer_loss)

        if getattr(self.args, 'use_label_encoding', False):
            losses.append(self.compute_label_encoding_final_loss(activations[-1], target))
        else:
            losses.append(self.final_criteria(activations[-1], target))
        return sum(losses), losses[-1].item()
def compute_scl_loss_convautoencoder(args, A, B, target, predictions):
    """Mean-squared error between ``A`` and ``B`` with a fixed scale of 1.

    ``args``, ``target`` and ``predictions`` are accepted but not used.

    Returns:
        Scalar loss tensor.
    """
    scale = 1
    return F.mse_loss(A, B) * scale
class C2Loss_ConvAutoencoder(nn.Module):
    """Sums the local criterion over every (activation, signal) pair."""
    def __init__(self, args):
        super(C2Loss_ConvAutoencoder, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_scl_loss_convautoencoder
        self.method = args.method

    def forward(self, activations, signals, target):
        """Compute the summed layer-wise loss.

        Args:
            activations: Sequence of layer outputs.
            signals: Sequence of per-layer targets paired with ``activations``.
            target: Forwarded unchanged to the local criterion.

        Returns:
            ``(sum_of_losses, float)`` where the float is the last pair's loss.
        """
        per_layer = []
        prediction = activations[-1]
        for act, sig in zip(activations, signals):
            act_rank, sig_rank = len(act.shape), len(sig.shape)
            # Give a flat tensor the spatial layout of its 4-D partner.
            if act_rank == 4 and sig_rank == 2:
                sig = sig.view(sig.shape[0], sig.shape[1], act.shape[2], act.shape[3])
            elif act_rank == 2 and sig_rank == 4:
                act = act.view(act.shape[0], act.shape[1], sig.shape[2], sig.shape[3])
            per_layer.append(self.local_criteria(self.args, act, sig, target, prediction))
        return sum(per_layer), per_layer[-1].item()
